/*************************************************************************
 * The contents of this file are subject to the MYRICOM MYRINET          *
 * EXPRESS (MX) NETWORKING SOFTWARE AND DOCUMENTATION LICENSE (the       *
 * "License"); User may not use this file except in compliance with the  *
 * License.  The full text of the License can found in LICENSE.TXT       *
 *                                                                       *
 * Software distributed under the License is distributed on an "AS IS"   *
 * basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See  *
 * the License for the specific language governing rights and            *
 * limitations under the License.                                        *
 *                                                                       *
 * Copyright 2003 - 2004 by Myricom, Inc.  All rights reserved.          *
 *************************************************************************/

#include "mx_auto_config.h"
#include "mxsmpi.h"
#include "mx_debug.h"
#include "mx__debug_dump.h"
#include "mx__error.h"

#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/utsname.h>
#include <sys/time.h>

/*
 * Storage for the two predefined communicators. Presumably the
 * MXSMPI_COMM_WORLD / MXSMPI_COMM_SELF handles declared in mxsmpi.h refer
 * to these objects -- confirm against the header.
 */
struct mxsmpi_comm mxsmpi_comm_world;
struct mxsmpi_comm mxsmpi_comm_self;

/*
 * Initialise the library: fill in MXSMPI_COMM_WORLD and MXSMPI_COMM_SELF,
 * open one MX endpoint and connect it to every peer.
 *
 * Configuration is read from the environment:
 *   MXSMPI_NP    number of processes (default 1, must be positive)
 *   MXSMPI_NIC   first argument passed to mx_open_endpoint (default 0)
 *   MXSMPI_EID   endpoint id to open (default 0)
 *   MXSMPI_RANK  optional; only cross-checked (by assert) against the rank
 *                deduced below
 *   MXSMPI_FILE  optional machine file, one host name per line; ":0" is
 *                appended to names that contain no ':'. Without it every
 *                rank is assumed to run on this node.
 *
 * Host names are reused cyclically: rank i is on host (i % nb_host) and
 * uses endpoint id (i / nb_host). This process's rank is the slot whose
 * NIC id and endpoint id match the endpoint it opened.
 *
 * ARGC/ARGV are accepted for MPI compatibility but are not used.
 * Any failure prints a diagnostic and exits the process.
 */
int MXSMPI_Init(int *argc, char ***argv)
{
  int cc, nb_host, i, eid, eid2;
  int nic_index;
  uint64_t nic_id;
  struct utsname uts;
  char *np_env, *machfile, *eid_env, *rank_env, *nic_env;
  mx_endpoint_t ep;
  mx_endpoint_addr_t addr;
  struct mxsmpi_peer *peers;

  setvbuf(stdout, 0, _IOLBF, BUFSIZ);
  uname(&uts);

  np_env = getenv("MXSMPI_NP");
  rank_env = getenv("MXSMPI_RANK");
  nic_env = getenv("MXSMPI_NIC");
  nic_index = nic_env ? atoi(nic_env) : 0;
  eid_env = getenv("MXSMPI_EID");
  eid = eid_env ? atoi(eid_env) : 0;
  MXSMPI_COMM_WORLD->size = np_env ? atoi(np_env) : 1;
  MXSMPI_COMM_WORLD->rank = -1;
  if (MXSMPI_COMM_WORLD->size <= 0) {
    /* atoi() yields 0 for garbage; an empty peer table would be indexed below */
    fprintf(stderr, "%s:invalid MXSMPI_NP value\n", uts.nodename);
    exit(1);
  }

  peers = calloc(MXSMPI_COMM_WORLD->size, sizeof(peers[0]));
  if (!peers) {
    fprintf(stderr, "%s:cannot allocate peer table\n", uts.nodename);
    exit(1);
  }
  MXSMPI_COMM_WORLD->peers = peers;
  machfile = getenv("MXSMPI_FILE");
  if (!machfile) {
    nb_host = 1;
    snprintf(peers[0].name, sizeof(peers[0].name), "%s:0", uts.nodename);
  } else {
    FILE *f = fopen(machfile, "r");
    if (!f) {
      perror(machfile);
      exit(1);
    }
    for (nb_host = 0; nb_host < MXSMPI_COMM_WORLD->size; nb_host++) {
      char *name = peers[nb_host].name;
      size_t len;
      /* read 2 bytes short so that a ":0" suffix still fits */
      if (!fgets(name, sizeof(peers[nb_host].name) - 2, f))
        break;
      /* remove end-of-line; the last line of the file may not have one */
      len = strlen(name);
      if (len > 0 && name[len - 1] == '\n')
        name[len - 1] = 0;
      if (!strchr(name, ':'))
        strcat(name, ":0");
    }
    if (ferror(f)) {
      perror(machfile);
      exit(1);
    }
    fclose(f);
    if (nb_host == 0) {
      /* nb_host is used as a divisor below */
      fprintf(stderr, "%s:no host found in %s\n", uts.nodename, machfile);
      exit(1);
    }
  }

  mx_init();
  cc = mx_open_endpoint(nic_index, eid, 10, NULL, 0, &ep);
  if (cc != MX_SUCCESS) {
    fprintf(stderr,"Cannot open mx endpoint:%s(%d)\n", mx_strerror(cc),cc);
    exit(1);
  }
  mx_get_endpoint_addr(ep, &addr);
  mx_decompose_endpoint_addr(addr, &nic_id, &eid2);
  sleep(2);
  assert(eid == eid2);
  mx_set_error_handler(MX_ERRORS_RETURN);
  for (i = 0; i < MXSMPI_COMM_WORLD->size; i++) {
    int j;
    if (i >= nb_host)
      strcpy(peers[i].name, peers[i % nb_host].name);
    cc = mx_hostname_to_nic_id(peers[i].name, &peers[i].nic);
    if (cc != MX_SUCCESS) {
      fprintf(stderr,"%s:cannot convert %s to a nic id\n", uts.nodename, peers[i].name);
      exit(1);
    }
    if (peers[i].nic == nic_id && i / nb_host == eid) {
      /* exactly one slot may describe this process */
      assert(MXSMPI_COMM_WORLD->rank == -1);
      MXSMPI_COMM_WORLD->rank = i;
      assert(!rank_env || atoi(rank_env) == i);
    }
    /* peers may not have opened their endpoint yet: retry quietly with a
       growing back-off before the final, verbose attempt */
    for (j = 0; j < 5; j++) {
      int old_flag = mx__error_set_verbose(0);
      cc = mx_connect(ep, peers[i].nic, i / nb_host, 10, 10000, &peers[i].addr);
      mx__error_set_verbose(old_flag);
      if (cc == MX_SUCCESS)
        break;
      usleep((j+1)*(j+1)*100000);
    }
    if (cc != MX_SUCCESS) {
      /* one last time with default verbosity */
      cc = mx_connect(ep, peers[i].nic, i / nb_host, 10, 10000, &peers[i].addr);
      if (cc != MX_SUCCESS) {
        fprintf(stderr,"%s:cannot connect to %s/%d : %s\n", 
                uts.nodename, peers[i].name, i / nb_host, mx_strerror(cc));
        exit(1);
      }
    }
    mx__set_app_rank(ep, &peers[i].addr, i);
  }
  mx_set_error_handler(MX_ERRORS_ARE_FATAL);
  if (MXSMPI_COMM_WORLD->rank < 0) {
    /* otherwise COMM_SELF would point before the peer table */
    fprintf(stderr, "%s:cannot find own rank (eid %d) among %d peers\n",
            uts.nodename, eid, MXSMPI_COMM_WORLD->size);
    exit(1);
  }
  MXSMPI_COMM_WORLD->ep = ep;
  MXSMPI_COMM_WORLD->id = 8;

  MXSMPI_COMM_SELF->ep = ep;
  MXSMPI_COMM_SELF->peers = MXSMPI_COMM_WORLD->peers + MXSMPI_COMM_WORLD->rank;
  MXSMPI_COMM_SELF->size = 1;
  MXSMPI_COMM_SELF->rank = 0;
  MXSMPI_COMM_SELF->id = 10;
  return MXSMPI_SUCCESS;
}

int MXSMPI_Finalize(void)
{
  MXSMPI_Barrier(MXSMPI_COMM_WORLD);
  free(MXSMPI_COMM_WORLD->peers);
  MXSMPI_COMM_WORLD->peers = 0;
  mx_close_endpoint(MXSMPI_COMM_WORLD->ep);
  MXSMPI_COMM_WORLD->ep = 0;
  mx_finalize();
  return MXSMPI_SUCCESS;
}

/*
 * Combine callback used by the ring collectives: merges the COUNT bytes at
 * A into the COUNT bytes at B, leaving the result in B.
 */
typedef void (*comb_func)(void *a,void*b,int count);

/* Element-wise operators: fold A into the accumulator B. */
#define mpiop_do_sum(a,b) ((b) += (a))
#define mpiop_do_max(a,b) ((b) = (a) > (b) ? (a) : (b))
#define mpiop_do_min(a,b) ((b) = (a) > (b) ? (b) : (a))

/*
 * COMB(name, type, op) defines `static void name(void *ap, void *bp, int count)`.
 * It treats AP and BP as arrays of TYPE spanning COUNT bytes and applies
 * OP(ap[i], bp[i]) to every element. Trailing bytes that do not form a whole
 * element are ignored, and a non-positive COUNT is a no-op.
 */
#define COMB(name,type,op)                                      \
static void name(void *ap,void *bp,int count)                   \
{                                                               \
  const type *a = ap;                                           \
  type *b = bp;                                                 \
  size_t i, n = count > 0 ? (size_t)count / sizeof(type) : 0;   \
  for (i = 0; i < n; i++) {                                     \
    op(a[i], b[i]);                                             \
  }                                                             \
}

/* Sums wrap modulo 2^N, so the unsigned versions also serve signed data. */
COMB(u8_sum,uint8_t,mpiop_do_sum)
COMB(u16_sum,uint16_t,mpiop_do_sum)
COMB(u32_sum,uint32_t,mpiop_do_sum)
COMB(float_sum,float,mpiop_do_sum)
COMB(double_sum,double,mpiop_do_sum)

/* Combine callback that leaves both buffers untouched (used by MXSMPI_Barrier). */
void noop_op(void *a, void *b, int count)
{
  (void)a;
  (void)b;
  (void)count;
}


COMB(u8_max,uint8_t,mpiop_do_max)
COMB(u16_max,uint16_t,mpiop_do_max)
COMB(u32_max,uint32_t,mpiop_do_max)
COMB(s8_max,int8_t,mpiop_do_max)
COMB(s16_max,int16_t,mpiop_do_max)
COMB(s32_max,int32_t,mpiop_do_max)
COMB(float_max,float,mpiop_do_max)
COMB(double_max,double,mpiop_do_max)

COMB(u8_min,uint8_t,mpiop_do_min)
COMB(u16_min,uint16_t,mpiop_do_min)
COMB(u32_min,uint32_t,mpiop_do_min)
COMB(s8_min,int8_t,mpiop_do_min)
COMB(s16_min,int16_t,mpiop_do_min)
COMB(s32_min,int32_t,mpiop_do_min)
COMB(float_min,float,mpiop_do_min)
COMB(double_min,double,mpiop_do_min)

/*
 * Dispatch table of combine callbacks. MXSMPI_Reduce and MXSMPI_Allreduce
 * index it as op_array[op][(uint16_t)(datatype >> 32)].
 * Row 0 maps every type to noop_op; rows 1-3 select the *_sum, *_max and
 * *_min helpers respectively.
 * NOTE(review): the column order must match the type index that mxsmpi.h
 * stores in the upper 32 bits of MXSMPI_Datatype -- confirm. The max/min
 * rows suggest columns 0-3 are signed and 4-7 unsigned integer types, and
 * 8/9 are float/double.
 * TODO(review): columns 3 and 7 use the 32-bit helpers; if those datatypes
 * are 64 bits wide on some platform, 64-bit combine functions are needed.
 */
static comb_func op_array[MXSMPI_NB_OPS][MXSMPI_NB_TYPES] = {
  { noop_op, noop_op, noop_op, noop_op, noop_op, noop_op, noop_op, noop_op, noop_op, noop_op },
  { u8_sum, u16_sum, u32_sum, u32_sum, u8_sum, u16_sum, u32_sum, u32_sum, float_sum, double_sum },
  { s8_max, s16_max, s32_max, s32_max, u8_max, u16_max, u32_max, u32_max, float_max, double_max },
  { s8_min, s16_min, s32_min, s32_min, u8_min, u16_min, u32_min, u32_min, float_min, double_min },
};

/*
 * Combine callback used by MXSMPI_Bcast: copy COUNT bytes from B into A.
 * MXSMPI_Bcast passes the same buffer as both arguments, and memcpy is
 * undefined for overlapping regions, so memmove is used and an identical
 * source/destination is skipped outright. A non-positive COUNT is a no-op.
 */
static void do_copy(void *a, void *b, int count)
{
  if (count <= 0 || a == b)
    return;
  memmove(a, b, (size_t)count);
}

/*
 * One trip of COUNT bytes around the rank ring of COMM: every rank receives
 * from its left neighbour (rank - 1) and sends to its right neighbour
 * (rank + 1), both wrapping around.
 *
 * The root starts the ring by sending SENDBUF and then waits for the data to
 * come back into RECVBUF. Every other rank receives into RECVBUF, calls
 * FUNC(SENDBUF, RECVBUF, COUNT), and forwards RECVBUF. Messages are matched
 * on (comm->id + 1, sender rank, TAG).
 */
void do_coll(void* sendbuf, void* recvbuf, int count, int tag, comb_func func, int root, MXSMPI_Comm comm)
{
  const int me = comm->rank;
  const int right = (me + 1 == comm->size) ? 0 : me + 1;
  const int left = (me == 0) ? comm->size - 1 : me - 1;
  mx_segment_t in_seg, out_seg;
  mx_request_t send_req, recv_req;
  mx_status_t status;
  uint32_t result;

  in_seg.segment_ptr = recvbuf;
  in_seg.segment_length = count;
  out_seg.segment_ptr = sendbuf;
  out_seg.segment_length = count;

  if (me != root) {
    /* wait for the left neighbour, merge our contribution, pass it on */
    mx_irecv(comm->ep, &in_seg, 1, MXSMPI_MATCH(comm->id+1,left,tag), MX_MATCH_MASK_NONE, NULL, &recv_req);
    mx_wait(comm->ep, &recv_req, MX_INFINITE, &status, &result);
    func(sendbuf, recvbuf, count);
    mx_isend(comm->ep, &in_seg, 1, comm->peers[right].addr, MXSMPI_MATCH(comm->id+1,me,tag), NULL, &send_req);
    mx_wait(comm->ep, &send_req, MX_INFINITE, &status, &result);
    return;
  }

  /* root: launch the ring, then collect what comes back around */
  mx_isend(comm->ep, &out_seg, 1, comm->peers[right].addr, MXSMPI_MATCH(comm->id+1,me,tag), NULL, &send_req);
  mx_irecv(comm->ep, &in_seg, 1, MXSMPI_MATCH(comm->id+1,left,tag), MX_MATCH_MASK_NONE, NULL, &recv_req);
  mx_wait(comm->ep, &send_req, MX_INFINITE, &status, &result);
  mx_wait(comm->ep, &recv_req, MX_INFINITE, &status, &result);
}

/*
 * Block until every rank of COMM has entered the barrier. Two empty ring
 * passes are made (tags 0 and 1): after the first only the root knows
 * everyone has arrived, and the second pass relays that to the others.
 * A single-rank communicator returns immediately.
 */
int MXSMPI_Barrier(MXSMPI_Comm comm)
{
  int pass;

  if (comm->size <= 1)
    return MXSMPI_SUCCESS;
  for (pass = 0; pass < 2; pass++)
    do_coll(0, 0, 0, pass, noop_op, 0, comm);
  return MXSMPI_SUCCESS;
}

/*
 * Broadcast COUNT elements of DATATYPE from ROOT's BUFFER to every rank,
 * relayed around the ring with tag 2. The low 32 bits of DATATYPE give the
 * element size in bytes. A single-rank communicator returns immediately.
 */
int MXSMPI_Bcast(void* buffer, int count, MXSMPI_Datatype datatype, int root, MXSMPI_Comm comm )
{
  uint32_t nbytes;

  if (comm->size <= 1)
    return MXSMPI_SUCCESS;
  nbytes = count * (uint32_t)datatype;
  do_coll(buffer, buffer, nbytes, 2, do_copy, root, comm);
  return MXSMPI_SUCCESS;
}

/*
 * Reduce COUNT elements of DATATYPE from every rank's SENDBUF with OP; the
 * result lands in ROOT's RECVBUF (ring pass, tag 3). On non-root ranks
 * RECVBUF is used as scratch space. The combine callback is chosen from
 * op_array using the type index in the upper 32 bits of DATATYPE.
 */
int MXSMPI_Reduce(void* sendbuf, void* recvbuf, int count, MXSMPI_Datatype datatype, MXSMPI_Op op, int root, MXSMPI_Comm comm)
{
  const uint16_t type_index = (uint16_t)(datatype >> 32);
  const uint32_t nbytes = count * (uint32_t)datatype;
  comb_func combine = op_array[op][type_index];

  do_coll(sendbuf, recvbuf, nbytes, 3, combine, root, comm);
  return MXSMPI_SUCCESS;
}

/*
 * Reduce COUNT elements of DATATYPE with OP and leave the result in every
 * rank's RECVBUF. This is done in two steps: a reduction onto rank 0
 * (tag 4), then a broadcast of rank 0's RECVBUF.
 */
int MXSMPI_Allreduce(void* sendbuf, void* recvbuf, int count, MXSMPI_Datatype datatype, MXSMPI_Op op, MXSMPI_Comm comm)
{
  const uint32_t nbytes = count * (uint32_t)datatype;
  comb_func combine = op_array[op][(uint16_t)(datatype >> 32)];

  do_coll(sendbuf, recvbuf, nbytes, 4, combine, 0, comm);
  MXSMPI_Bcast(recvbuf, count, datatype, 0, comm);
  return MXSMPI_SUCCESS;
}

/*
 * Gather SENDCOUNT elements of SENDTYPE from every rank into ROOT's RECVBUF,
 * with rank i's block placed at byte offset i * length. Both sides must
 * describe the same number of bytes per rank (asserted).
 *
 * The root posts one receive per peer (tag 5), waits for all of them, and
 * finally copies its own contribution. Every other rank sends once to ROOT.
 * Allocation failure prints a message and exits the process.
 */
int MXSMPI_Gather(void* sendbuf, int sendcount, MXSMPI_Datatype sendtype, 
		  void* recvbuf, int recvcount, MXSMPI_Datatype recvtype, 
		  int root, MXSMPI_Comm comm)
{
  uint32_t length = sendcount * (uint32_t)sendtype;
  assert(length == recvcount * (uint32_t)recvtype);

  if (comm->rank == root) {
    int i;
    mx_segment_t r_seg;
    mx_request_t *req = malloc(comm->size*sizeof(req[0]));

    if (!req) {
      fprintf(stderr, "MXSMPI_Gather:cannot allocate request array\n");
      exit(1);
    }
    for (i=0;i<comm->size;i++) {
      if (i == comm->rank)
	continue;
      /* size_t arithmetic so offsets beyond 4 GiB do not wrap */
      r_seg.segment_ptr = (char*)recvbuf + (size_t)i * length;
      r_seg.segment_length = length;
      mx_irecv(comm->ep, &r_seg, 1, MXSMPI_MATCH(comm->id+1,i,5), MX_MATCH_MASK_NONE, NULL, &req[i]);
    }
    for (i=0;i<comm->size;i++) {
      mx_status_t s;
      uint32_t res;
      if (i == comm->rank)
	continue;
      mx_wait(comm->ep, &req[i], MX_INFINITE, &s, &res);
      assert(s.xfer_length == length);
    }
    free(req);
    /* memmove: the caller may gather "in place" so the regions can alias */
    memmove((char*)recvbuf + (size_t)comm->rank * length, sendbuf, length);
  } else {
    struct mxsmpi_peer *root_node = comm->peers+root;
    mx_segment_t s_seg;
    mx_request_t req;
    mx_status_t s;
    uint32_t res;
    s_seg.segment_ptr = sendbuf;
    s_seg.segment_length = length;
    mx_isend(comm->ep, &s_seg, 1, root_node->addr, MXSMPI_MATCH(comm->id+1,comm->rank,5), NULL, &req);
    mx_wait(comm->ep,&req,MX_INFINITE,&s,&res);
  }
  return MXSMPI_SUCCESS;
}

/*
 * Scatter ROOT's SENDBUF: rank i receives the block at byte offset
 * i * length into its RECVBUF. Both sides must describe the same number of
 * bytes per rank (asserted).
 *
 * The root posts one send per peer (tag 6), waits for all of them, and
 * copies its own block locally. Every other rank receives once from ROOT.
 * Allocation failure prints a message and exits the process.
 */
int MXSMPI_Scatter(void* sendbuf, int sendcount, MXSMPI_Datatype sendtype, 
		   void* recvbuf, int recvcount, MXSMPI_Datatype recvtype, 
		   int root, MXSMPI_Comm comm)
{
  uint32_t length = sendcount * (uint32_t)sendtype;
  assert(length == recvcount * (uint32_t)recvtype);

  if (comm->rank == root) {
    int i;
    struct mxsmpi_peer *nodes = comm->peers;
    mx_segment_t s_seg;
    mx_request_t *req = malloc(comm->size*sizeof(req[0]));

    if (!req) {
      fprintf(stderr, "MXSMPI_Scatter:cannot allocate request array\n");
      exit(1);
    }
    for (i=0;i<comm->size;i++) {
      if (i == comm->rank)
	continue;
      /* size_t arithmetic so offsets beyond 4 GiB do not wrap */
      s_seg.segment_ptr = (char*)sendbuf + (size_t)i * length;
      s_seg.segment_length = length;
      mx_isend(comm->ep, &s_seg, 1, nodes[i].addr, MXSMPI_MATCH(comm->id+1,comm->rank,6), NULL, &req[i]);
    }
    for (i=0;i<comm->size;i++) {
      mx_status_t s;
      uint32_t res;
      if (i == comm->rank)
	continue;
      mx_wait(comm->ep, &req[i], MX_INFINITE, &s, &res);
    }
    free(req);
    memcpy(recvbuf, (char*)sendbuf + (size_t)comm->rank * length, length);
  } else {
    mx_segment_t r_seg;
    mx_request_t req;
    mx_status_t s;
    uint32_t res;

    r_seg.segment_ptr = recvbuf;
    r_seg.segment_length = length;
    mx_irecv(comm->ep, &r_seg, 1, MXSMPI_MATCH(comm->id+1,root,6), MX_MATCH_MASK_NONE, NULL, &req);
    mx_wait(comm->ep,&req,MX_INFINITE,&s,&res);
  }
  return MXSMPI_SUCCESS;
}

/* Wall-clock time in seconds since the epoch, from gettimeofday(). */
double MXSMPI_Wtime(void)
{
  struct timeval now;
  double whole, fraction;

  gettimeofday(&now, NULL);
  whole = now.tv_sec;
  fraction = now.tv_usec * 1e-6;
  return whole + fraction;
}

/* Resolution reported for MXSMPI_Wtime, in seconds (a fixed 10 microseconds). */
double MXSMPI_Wtick(void)
{
  return 0.00001;
}

/*
 * Print which rank of COMM is aborting, then terminate this process with
 * exit status CODE. Never returns.
 */
int MXSMPI_Abort(MXSMPI_Comm comm, int code)
{
  const int who = comm->rank;

  fprintf(stderr, "MXSMPI_Abort:Process rank %d aborted!!!\n", who);
  exit(code);
}

/*
 * Wait for every non-null request in IN_REQ[0..INCOUNT). Status entry k is
 * filled only for a non-null request k; null slots are skipped.
 */
int MXSMPI_Waitall(int incount, MXSMPI_Request in_req[], MXSMPI_Status status[])
{
  int k = 0;

  while (k < incount) {
    if (in_req[k])
      MXSMPI_Wait(&in_req[k], &status[k]);
    k++;
  }
  return MXSMPI_SUCCESS;
}

/*
 * Poll each non-null request in IN_REQ once. For every request that has
 * completed, its index is appended to IDX and its status is written to the
 * next free STATUS slot (STATUS is packed by completion order, not by
 * request index). *OUTCOUNT receives the number of completions. At least
 * one request must be non-null (asserted).
 */
int MXSMPI_Testsome(int incount, MXSMPI_Request in_req[], int *outcount, int idx[], MXSMPI_Status status[])
{
  int pos;
  int ndone = 0;
  int completed;
  int any_active = 0;

  for (pos = 0; pos < incount; pos++) {
    if (!in_req[pos])
      continue;
    any_active = 1;
    MXSMPI_Test(&in_req[pos], &completed, &status[ndone]);
    if (!completed)
      continue;
    /* a completed request must have been reset to null by MXSMPI_Test */
    assert(in_req[pos] == 0);
    idx[ndone++] = pos;
  }
  assert(any_active);
  *outcount = ndone;
  return MXSMPI_SUCCESS;
}

/*
 * Busy-poll MXSMPI_Testsome until at least one request completes. Outputs
 * have the same meaning as for MXSMPI_Testsome.
 */
int MXSMPI_Waitsome(int incount, MXSMPI_Request in_req[], int *outcount, int idx[], MXSMPI_Status status[])
{
  for (;;) {
    MXSMPI_Testsome(incount, in_req, outcount, idx, status);
    if (*outcount != 0)
      break;
  }
  return MXSMPI_SUCCESS;
}


/* Report the byte size of DATATYPE, which is encoded in its low 32 bits. */
int MXSMPI_Type_size(MXSMPI_Datatype datatype, int *size)
{
  const uint32_t bytes = (uint32_t)datatype;

  *size = bytes;
  return MXSMPI_SUCCESS;
}

int MXSMPI_Get_processor_name(char *name, int *resultlen)
{
  strcpy(name, MXSMPI_COMM_WORLD->peers[MXSMPI_COMM_WORLD->rank].name);
  *resultlen = strlen(name);
  return MXSMPI_SUCCESS;
}


/* Create an error handler; the handle is simply FUNCTION itself. */
int MXSMPI_Errhandler_create(MXSMPI_Handler_function *function, MXSMPI_Errhandler *errhandler)
{
  errhandler[0] = function;
  return MXSMPI_SUCCESS;
}

/* Not supported: any call reports a fatal error through mx_fatal(). */
int MXSMPI_Errhandler_set(MXSMPI_Comm comm, MXSMPI_Errhandler errhandler)
{
  (void)comm;
  (void)errhandler;
  mx_fatal("MXSMPI_Errhandler_set not implemented");
  return MXSMPI_SUCCESS;
}

/* Release an error handler: nothing is allocated, so only clear the handle. */
int MXSMPI_Errhandler_free(MXSMPI_Errhandler *errhandler)
{
  errhandler[0] = 0;
  return MXSMPI_SUCCESS;
}

int MXSMPI_Error_string(int errorcode, char *string, int *resultlen)
{
  strcpy(string,errorcode == MXSMPI_SUCCESS ? "Success" : "Unknwown error");
  *resultlen = strlen(string);
  return MXSMPI_SUCCESS;
}


/*
 * Build a datatype made of COUNT consecutive OLDTYPE elements. The type
 * index in the upper 32 bits is kept, and the byte size in the low 32 bits
 * is multiplied by COUNT (modulo 2^32).
 */
int MXSMPI_Type_contiguous(int count, MXSMPI_Datatype oldtype, MXSMPI_Datatype *newtype)
{
  const uint32_t elem_size = (uint32_t)oldtype;

  *newtype = (oldtype & 0xffffffff00000000ULL) | (elem_size * count);
  return MXSMPI_SUCCESS;
}

/* Datatypes need no preparation here, so committing is a no-op. */
int MXSMPI_Type_commit(MXSMPI_Datatype * type)
{
  (void)type;
  return MXSMPI_SUCCESS;
}

/* Datatypes own no resources, so freeing is a no-op (the handle is untouched). */
int MXSMPI_Type_free(MXSMPI_Datatype * type)
{
  (void)type;
  return MXSMPI_SUCCESS;
}


/*
 * Duplicate COMM into a new communicator with its own copy of the peer
 * table (stored right after the struct in one allocation) and a fresh id.
 * Ids advance by 2 because collectives match on id + 1 (see do_coll).
 * Allocation failure prints a message and exits the process.
 */
int MXSMPI_Comm_dup(MXSMPI_Comm comm, MXSMPI_Comm *dup)
{
  MXSMPI_Comm new;
  /* Last id handed out. Start above the ids MXSMPI_Init assigns to
     COMM_WORLD (8) and COMM_SELF (10); starting at 0 made the 4th and 5th
     duplicates reuse those ids and share their message matching. */
  static int id_cnt = 10;

  new = malloc(sizeof(*new) + sizeof(struct mxsmpi_peer) * comm->size);
  if (!new) {
    fprintf(stderr, "MXSMPI_Comm_dup:cannot allocate communicator\n");
    exit(1);
  }
  *new = *comm;
  new->peers = (void*)(new + 1);
  id_cnt += 2;
  new->id = id_cnt;
  /* NOTE(review): presumably the id must fit the MXSMPI_MATCH field -- confirm */
  assert(new->id < 32767);
  memcpy(new->peers, comm->peers, sizeof(new->peers[0]) * comm->size);
  *dup = new;
  return MXSMPI_SUCCESS;
}
